import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, train_test_split
from sample_split import split_sample_cv, balance_train_size
from util import eval_balance


def classification_onefold(df, Y_column, G_column, X_columns, clf_b, clf_r,
                           seed=42, ratio=1.0, verbose=True, test_size=0.33,
                           standardize=False):
    """Perform classification with one train/test split

    ``clf_b`` is trained on the blue group's training data and ``clf_r`` on
    the red group's training data; both are then evaluated on the blue and
    the red test sets.

    Parameters
    ----------
    df : pandas.DataFrame
        whole dataset
    Y_column : str
        outcome variable
    G_column : str
        group variable
    X_columns : list of str
        covariates
    clf_b, clf_r : sklearn.base.BaseEstimator
        classifiers from scikit-learn; they are fitted in place (clf_b on
        the blue training data, clf_r on the red training data)
    seed : int, optional
        seeds NumPy's global RNG and the train/test split, by default 42
    ratio : float, optional
        ratio of the large group to the small group, by default 1.0
        use all data points if None
    verbose : bool, optional
        print the group losses and the skewness label, and forward the flag
        to ``balance_train_size``, by default True
    test_size : float, optional
        passed to ``train_test_split`` (fraction of ``df`` held out when a
        float), by default 0.33
    standardize : bool, optional
        forwarded to ``split_sample_cv``, by default False

    Returns
    -------
    res : dict
        'clf_b', 'clf_r' : the fitted classifiers
        'res_blue', 'res_red' : ``est_group_loss`` output for each of them
    X_test_blue, X_test_red : pandas.DataFrame
    y_test_blue, y_test_red : np.array
        test sets
    """
    np.random.seed(seed)
    train_idx, test_idx =\
        train_test_split(df.index, test_size=test_size, random_state=seed)
    X_test_blue, y_test_blue, X_test_red, y_test_red, train =\
        split_sample_cv(G_column, X_columns, Y_column, df, train_idx, test_idx,
                        standardize=standardize)

    X_train_blue_full, X_train_red_full, y_train_blue_full, y_train_red_full =\
        balance_train_size(train, G_column, X_columns, Y_column, ratio,
                           verbose=verbose)
    # blue-trained model, evaluated on both groups
    clf_b.fit(X_train_blue_full, y_train_blue_full)
    res_blue = est_group_loss(clf_b, X_test_blue, y_test_blue,
                              X_test_red, y_test_red)
    e_b_B = res_blue['e_b']
    e_r_B = res_blue['e_r']

    # red-trained model, evaluated on both groups
    clf_r.fit(X_train_red_full, y_train_red_full)
    res_red = est_group_loss(clf_r, X_test_blue, y_test_blue,
                             X_test_red, y_test_red)
    e_b_R = res_red['e_b']
    e_r_R = res_red['e_r']

    if verbose:
        # np.round (instead of the .round method) also works when the
        # losses are plain Python floats rather than numpy scalars
        print(f'[blue] e_b={np.round(e_b_B, 3)}, e_r={np.round(e_r_B, 3)}, '
              f'fairness={np.round(np.abs(e_b_B - e_r_B), 3)}')
        print(f'[red] e_b={np.round(e_b_R, 3)}, e_r={np.round(e_r_R, 3)}, '
              f'fairness={np.round(np.abs(e_b_R - e_r_R), 3)}')

        # group-balanced or not; the first matching condition wins
        if e_b_R < e_r_R:
            print('b-skewed')
        elif e_r_B < e_b_B:
            print('r-skewed')
        else:
            print('balanced')

    res = {'clf_b': clf_b, 'clf_r': clf_r,
           'res_blue': res_blue, 'res_red': res_red}

    return res, X_test_blue, y_test_blue, X_test_red, y_test_red


def bs_p_value_45(clf_b, clf_r, X_test_blue, y_test_blue, X_test_red,
                  y_test_red, e_b_B, e_r_B, e_b_R, e_r_R,
                  delta=0.01, n_bs=1000, seed=42):
    """Bootstrap p-values for the within-classifier gap |e_b - e_r| vs. delta

    The blue and the red test sets are resampled with replacement
    (separately) ``n_bs`` times, and both classifiers are re-evaluated on
    every resample. For each classifier the statistic is
    ``T = |e_b - e_r| - delta``. Its p-value is the share of centered
    bootstrap statistics ``T* - T`` that are ``<= T``. A small value
    presumably supports ``|e_b - e_r| < delta`` (TODO confirm the intended
    hypothesis).

    Parameters
    ----------
    clf_b, clf_r : sklearn.base.BaseEstimator
        fitted classifiers trained on the blue / red group
    X_test_blue, y_test_blue, X_test_red, y_test_red
        test sets; X's are pandas DataFrames, y's are numpy arrays (a pandas
        Series is also accepted and is indexed by position)
    e_b_B, e_r_B, e_b_R, e_r_R : float
        losses on the full test sets, named ``e_<group>_<model>``; e.g.
        ``e_b_R`` is the blue-group loss of ``clf_r``
    delta : float, optional
        margin for the difference, by default 0.01
    n_bs : int, optional
        number of bootstrap samples, by default 1000
    seed : int, optional
        seed for NumPy's global RNG, by default 42

    Returns
    -------
    p_b : float
        p-value from clf_r's gap, |e_b_R - e_r_R|
    p_r : float
        p-value from clf_b's gap, |e_r_B - e_b_B|

    Notes
    -----
    Prints progress every 500 resamples and saves the four bootstrap loss
    arrays to ``e_b_B_list_{seed}_{n_bs}.npy`` (and the e_r_B, e_b_R, e_r_R
    counterparts) in the working directory. The caller's test sets are not
    modified.
    """
    np.random.seed(seed)
    # Local, positionally indexed copies: the caller's objects are left
    # untouched, and rows are selected by position below.
    X_test_blue = X_test_blue.reset_index(drop=True)
    X_test_red = X_test_red.reset_index(drop=True)
    y_test_blue = np.asarray(y_test_blue)
    y_test_red = np.asarray(y_test_red)
    n_blue = X_test_blue.shape[0]
    n_red = X_test_red.shape[0]
    e_b_B_list = []
    e_r_B_list = []
    e_b_R_list = []
    e_r_R_list = []
    for i in range(n_bs):
        if i % 500 == 0:
            print(f'bootstrap sample {i}')
        idx_blue = np.random.choice(n_blue, n_blue, replace=True)
        idx_red = np.random.choice(n_red, n_red, replace=True)
        X_test_blue_bs = X_test_blue.iloc[idx_blue]
        y_test_blue_bs = y_test_blue[idx_blue]
        X_test_red_bs = X_test_red.iloc[idx_red]
        y_test_red_bs = y_test_red[idx_red]
        res_blue = est_group_loss(clf_b, X_test_blue_bs, y_test_blue_bs,
                                  X_test_red_bs, y_test_red_bs)
        e_b_B_list.append(res_blue['e_b'])
        e_r_B_list.append(res_blue['e_r'])
        res_red = est_group_loss(clf_r, X_test_blue_bs, y_test_blue_bs,
                                 X_test_red_bs, y_test_red_bs)
        e_b_R_list.append(res_red['e_b'])
        e_r_R_list.append(res_red['e_r'])
    e_b_B_list = np.array(e_b_B_list)
    e_r_B_list = np.array(e_r_B_list)
    e_b_R_list = np.array(e_b_R_list)
    e_r_R_list = np.array(e_r_R_list)
    # f-strings so that seed and n_bs actually appear in the file names
    np.save(f'e_b_B_list_{seed}_{n_bs}.npy', e_b_B_list)
    np.save(f'e_r_B_list_{seed}_{n_bs}.npy', e_r_B_list)
    np.save(f'e_b_R_list_{seed}_{n_bs}.npy', e_b_R_list)
    np.save(f'e_r_R_list_{seed}_{n_bs}.npy', e_r_R_list)

    # compute p-values (centered bootstrap: T* - T compared with T)
    b_diff_list = np.abs(e_b_R_list - e_r_R_list) - delta
    r_diff_list = np.abs(e_r_B_list - e_b_B_list) - delta
    b_diff = np.abs(e_b_R - e_r_R) - delta
    r_diff = np.abs(e_r_B - e_b_B) - delta

    p_b = np.mean(b_diff_list - b_diff <= b_diff)
    p_r = np.mean(r_diff_list - r_diff <= r_diff)
    p_max = np.max([p_b, p_r])

    print(f'p_b: {p_b}, p_r: {p_r}, p_max: {p_max}')

    return p_b, p_r


def bs_p_value_same_optimal(clf_b, clf_r, X_test_blue, y_test_blue, X_test_red,
                            y_test_red, e_b_B, e_r_B, e_b_R, e_r_R,
                            delta=0.005, n_bs=1000, seed=42):
    """Bootstrap p-values for the between-classifier gap within each group

    The blue and the red test sets are resampled with replacement
    (separately) ``n_bs`` times, and both classifiers are re-evaluated on
    every resample. For each group the statistic compares the two
    classifiers' losses on that group, ``T = |e_g_R - e_g_B| - delta``. Its
    p-value is the share of centered bootstrap statistics ``T* - T`` that
    are ``<= T``. A small value presumably supports both classifiers having
    (nearly) the same loss on that group (TODO confirm the intended
    hypothesis).

    Parameters
    ----------
    clf_b, clf_r : sklearn.base.BaseEstimator
        fitted classifiers trained on the blue / red group
    X_test_blue, y_test_blue, X_test_red, y_test_red
        test sets; X's are pandas DataFrames, y's are numpy arrays (a pandas
        Series is also accepted and is indexed by position)
    e_b_B, e_r_B, e_b_R, e_r_R : float
        losses on the full test sets, named ``e_<group>_<model>``; e.g.
        ``e_b_R`` is the blue-group loss of ``clf_r``
    delta : float, optional
        margin for the difference, by default 0.005
    n_bs : int, optional
        number of bootstrap samples, by default 1000
    seed : int, optional
        seed for NumPy's global RNG, by default 42

    Returns
    -------
    p_b : float
        p-value from the blue-group gap |e_b_R - e_b_B|
    p_r : float
        p-value from the red-group gap |e_r_B - e_r_R|
    e_b_B_list, e_r_B_list, e_b_R_list, e_r_R_list : np.ndarray
        bootstrap losses, one entry per resample

    Notes
    -----
    Prints progress every 500 resamples. The caller's test sets are not
    modified.
    """
    np.random.seed(seed)
    # Local, positionally indexed copies: the caller's objects are left
    # untouched, and rows are selected by position below.
    X_test_blue = X_test_blue.reset_index(drop=True)
    X_test_red = X_test_red.reset_index(drop=True)
    y_test_blue = np.asarray(y_test_blue)
    y_test_red = np.asarray(y_test_red)
    n_blue = X_test_blue.shape[0]
    n_red = X_test_red.shape[0]
    e_b_B_list = []
    e_r_B_list = []
    e_b_R_list = []
    e_r_R_list = []
    for i in range(n_bs):
        if i % 500 == 0:
            print(f'bootstrap sample {i}')
        idx_blue = np.random.choice(n_blue, n_blue, replace=True)
        idx_red = np.random.choice(n_red, n_red, replace=True)
        X_test_blue_bs = X_test_blue.iloc[idx_blue]
        y_test_blue_bs = y_test_blue[idx_blue]
        X_test_red_bs = X_test_red.iloc[idx_red]
        y_test_red_bs = y_test_red[idx_red]
        res_blue = est_group_loss(clf_b, X_test_blue_bs, y_test_blue_bs,
                                  X_test_red_bs, y_test_red_bs)
        e_b_B_list.append(res_blue['e_b'])
        e_r_B_list.append(res_blue['e_r'])
        res_red = est_group_loss(clf_r, X_test_blue_bs, y_test_blue_bs,
                                 X_test_red_bs, y_test_red_bs)
        e_b_R_list.append(res_red['e_b'])
        e_r_R_list.append(res_red['e_r'])
    e_b_B_list = np.array(e_b_B_list)
    e_r_B_list = np.array(e_r_B_list)
    e_b_R_list = np.array(e_b_R_list)
    e_r_R_list = np.array(e_r_R_list)

    # compute p-values (centered bootstrap: T* - T compared with T)
    b_diff_list = np.abs(e_b_R_list - e_b_B_list) - delta
    r_diff_list = np.abs(e_r_B_list - e_r_R_list) - delta
    b_diff = np.abs(e_b_R - e_b_B) - delta
    r_diff = np.abs(e_r_B - e_r_R) - delta

    p_b = np.mean(b_diff_list - b_diff <= b_diff)
    p_r = np.mean(r_diff_list - r_diff <= r_diff)
    p_max = np.max([p_b, p_r])

    print(f'p_b: {p_b}, p_r: {p_r}, p_max: {p_max}')

    return p_b, p_r, e_b_B_list, e_r_B_list, e_b_R_list, e_r_R_list


def CV_classification(df, Y_column, G_column, X_columns, n_splits=5, seed=42,
                      ratio=1.0, clf=None, n_jobs=None, verbose=True,
                      standardize=False, train_frac=1.0):
    """K-fold cross-validation of group-specific classifiers

    In every fold the same estimator is fitted twice: first on the blue
    training data, then on the red training data. After each fit it is
    scored on both the blue and the red test sets.

    Parameters
    ----------
    df : pandas.DataFrame
        whole dataset (its index is reset internally; the caller's frame is
        not modified)
    Y_column : str
        outcome variable
    G_column : str
        group variable; rows with value 1 are "blue", rows with 0 are "red"
        when ``train_frac`` subsampling is applied
    X_columns : list of str
        covariates
    n_splits : int, optional
        number of CV folds, by default 5
    seed : int, optional
        seed for the fold shuffling, the default classifier and the
        subsampling, by default 42
    ratio : float, optional
        ratio of the large group to the small group, by default 1.0
    clf : sklearn.base.BaseEstimator, optional
        classifier from scikit-learn (refitted in place); when None a
        300-tree RandomForestClassifier is used, by default None
    n_jobs : int, optional
        ``n_jobs`` of the default RandomForestClassifier only, by default None
    verbose : bool, optional
        print the averaged losses and the balance summary, by default True
    standardize : bool, optional
        forwarded to ``split_sample_cv``, by default False
    train_frac : float, optional
        train the model using only a subset of the training set, by default 1.0

    Returns
    -------
    res : dict
        'e_b_blue', 'e_r_blue' : per-fold blue/red losses of the blue-trained model
        'e_b_red', 'e_r_red' : per-fold blue/red losses of the red-trained model
        'balance' : output of ``eval_balance`` on the fold-averaged losses
    """
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    data = df.reset_index(drop=True)

    if clf is None:
        clf = RandomForestClassifier(n_estimators=300, max_features='sqrt',
                                     random_state=seed, n_jobs=n_jobs)

    # per-fold losses: blue-trained model first, then red-trained model
    blue_b, blue_r, red_b, red_r = [], [], [], []

    for fit_idx, eval_idx in folds.split(data):
        Xb_test, yb_test, Xr_test, yr_test, train =\
            split_sample_cv(G_column, X_columns, Y_column, data,
                            fit_idx, eval_idx, standardize=standardize)

        if train_frac < 1.0:
            # subsample each group separately (blue rows first, then red)
            parts = [
                train[train[G_column] == grp].sample(frac=train_frac,
                                                     random_state=seed)
                for grp in (1, 0)
            ]
            train = pd.concat(parts, axis=0).reset_index(drop=True)

        Xb_fit, Xr_fit, yb_fit, yr_fit = balance_train_size(
            train, G_column, X_columns, Y_column, ratio, verbose=False)

        # fit on one group, score on both, record; blue pass then red pass
        passes = ((Xb_fit, yb_fit, blue_b, blue_r),
                  (Xr_fit, yr_fit, red_b, red_r))
        for X_fit, y_fit, store_b, store_r in passes:
            clf.fit(X_fit, y_fit)
            losses = est_group_loss(clf, Xb_test, yb_test, Xr_test, yr_test)
            store_b.append(losses['e_b'])
            store_r.append(losses['e_r'])

    e_b_blue, e_r_blue, e_b_red, e_r_red = (
        np.mean(v) for v in (blue_b, blue_r, red_b, red_r))
    res_balance = eval_balance(e_b_blue, e_r_blue, e_b_red, e_r_red)

    if verbose:
        print()
        print(f'blue: {G_column}=1')
        print(f'[blue] e_b={e_b_blue.round(3)}, e_r={e_r_blue.round(3)}, fairness={np.abs(e_b_blue - e_r_blue).round(3)}')
        print(f'[red] e_b={e_b_red.round(3)}, e_r={e_r_red.round(3)}, fairness={np.abs(e_b_red - e_r_red).round(3)}')
        print(res_balance)

    return {'e_b_blue': blue_b, 'e_r_blue': blue_r,
            'e_b_red': red_b, 'e_r_red': red_r,
            'balance': res_balance}


def est_group_loss(model, X_test_blue, y_test_blue, X_test_red, y_test_red):
    '''Estimate per-group misclassification rates of a fitted classifier.

    Returns a dict with
      'e_b', 'e_r'         : 1 - accuracy on the blue / red test set
      'cm_blue', 'cm_red'  : confusion matrices of the blue / red test set
    '''
    # predict() gives hard class labels (not probabilities), so the
    # "loss" is the plain misclassification rate
    pred_blue = model.predict(X_test_blue)
    pred_red = model.predict(X_test_red)
    return {
        'e_b': 1 - accuracy_score(y_test_blue, pred_blue),
        'e_r': 1 - accuracy_score(y_test_red, pred_red),
        'cm_blue': confusion_matrix(y_test_blue, pred_blue),
        'cm_red': confusion_matrix(y_test_red, pred_red),
    }


def no_info_utilitarian(df, Y_column, G_column, X_columns, standardize=False,
                        n_splits=5, seed=42, verbose=True):
    """Cross-validated group losses of the best constant (covariate-free) rule

    In each of ``n_splits`` folds, the most frequent outcome label in the
    training fold is predicted for every test point. The misclassification
    rates on the blue and the red test sets are then recorded.

    Parameters
    ----------
    df : pandas.DataFrame
        whole dataset (its index is reset internally; the caller's frame is
        not modified)
    Y_column : str
        outcome variable
    G_column : str
        group variable
    X_columns : list of str
        covariates (only used by ``split_sample_cv``)
    standardize : bool, optional
        forwarded to ``split_sample_cv``, by default False
    n_splits : int, optional
        number of CV folds, by default 5
    seed : int, optional
        seed for the fold shuffling, by default 42
    verbose : bool, optional
        print the fold counter, by default True

    Returns
    -------
    res : dict
        'e_b', 'e_r' : mean blue / red error across folds
        'e_b_list', 'e_r_list' : per-fold blue / red errors
    y_test_blue : array-like
        blue test outcomes of the last fold
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    df = df.reset_index(drop=True, inplace=False)
    e_b_list = []
    e_r_list = []

    for counter, (train_index, test_index) in enumerate(kf.split(df)):
        if verbose:
            print(f'iteration: {counter}')
        _, y_test_blue, _, y_test_red, train =\
            split_sample_cv(G_column, X_columns, Y_column, df,
                            train_index, test_index,
                            standardize=standardize)
        # Majority label of the training fold. np.unique sorts the labels
        # and argmax returns the first maximum, so ties go to the smallest
        # label (0 for 0/1 outcomes). This works for any label encoding,
        # not just {0, 1}.
        labels, counts = np.unique(train[Y_column].values, return_counts=True)
        d = labels[np.argmax(counts)]
        e_b_list.append(np.mean(y_test_blue != d))
        e_r_list.append(np.mean(y_test_red != d))

    res = {}
    res['e_b'] = np.mean(e_b_list)
    res['e_r'] = np.mean(e_r_list)
    res['e_b_list'] = e_b_list
    res['e_r_list'] = e_r_list

    return res, y_test_blue


def bs_p_value(clf_b, clf_r, X_test_blue, y_test_blue, X_test_red, y_test_red,
               e_b_B, e_r_B, e_b_R, e_r_R, n_bs=1000, seed=42):
    """Bootstrap p-values for the signed gap between group losses

    The blue and the red test sets are resampled with replacement
    (separately) ``n_bs`` times, and both classifiers are re-evaluated on
    every resample. For statistic ``T`` the p-value is the share of centered
    bootstrap statistics ``T* - T`` that are ``>= T``. A small value
    presumably indicates ``T > 0`` (TODO confirm the intended hypothesis).

    Parameters
    ----------
    clf_b, clf_r : sklearn.base.BaseEstimator
        fitted classifiers trained on the blue / red group
    X_test_blue, y_test_blue, X_test_red, y_test_red
        test sets; X's are pandas DataFrames, y's are numpy arrays (a pandas
        Series is also accepted and is indexed by position)
    e_b_B, e_r_B, e_b_R, e_r_R : float
        losses on the full test sets, named ``e_<group>_<model>``; e.g.
        ``e_b_R`` is the blue-group loss of ``clf_r``
    n_bs : int, optional
        number of bootstrap samples, by default 1000
    seed : int, optional
        seed for NumPy's global RNG, by default 42

    Returns
    -------
    p_b : float
        p-value for ``T = e_b_R - e_r_R`` (clf_r's losses)
    p_r : float
        p-value for ``T = e_r_B - e_b_B`` (clf_b's losses)

    Notes
    -----
    Prints progress every 500 resamples and saves the four bootstrap loss
    arrays to ``e_b_B_list_{seed}_{n_bs}.npy`` (and the e_r_B, e_b_R, e_r_R
    counterparts) in the working directory. The caller's test sets are not
    modified.
    """
    np.random.seed(seed)
    # Local, positionally indexed copies: the caller's objects are left
    # untouched, and rows are selected by position below.
    X_test_blue = X_test_blue.reset_index(drop=True)
    X_test_red = X_test_red.reset_index(drop=True)
    y_test_blue = np.asarray(y_test_blue)
    y_test_red = np.asarray(y_test_red)
    n_blue = X_test_blue.shape[0]
    n_red = X_test_red.shape[0]
    e_b_B_list = []
    e_r_B_list = []
    e_b_R_list = []
    e_r_R_list = []
    for i in range(n_bs):
        if i % 500 == 0:
            print(f'bootstrap sample {i}')
        idx_blue = np.random.choice(n_blue, n_blue, replace=True)
        idx_red = np.random.choice(n_red, n_red, replace=True)
        X_test_blue_bs = X_test_blue.iloc[idx_blue]
        y_test_blue_bs = y_test_blue[idx_blue]
        X_test_red_bs = X_test_red.iloc[idx_red]
        y_test_red_bs = y_test_red[idx_red]
        res_blue = est_group_loss(clf_b, X_test_blue_bs, y_test_blue_bs,
                                  X_test_red_bs, y_test_red_bs)
        e_b_B_list.append(res_blue['e_b'])
        e_r_B_list.append(res_blue['e_r'])
        res_red = est_group_loss(clf_r, X_test_blue_bs, y_test_blue_bs,
                                 X_test_red_bs, y_test_red_bs)
        e_b_R_list.append(res_red['e_b'])
        e_r_R_list.append(res_red['e_r'])
    e_b_B_list = np.array(e_b_B_list)
    e_r_B_list = np.array(e_r_B_list)
    e_b_R_list = np.array(e_b_R_list)
    e_r_R_list = np.array(e_r_R_list)
    # f-strings so that seed and n_bs actually appear in the file names
    np.save(f'e_b_B_list_{seed}_{n_bs}.npy', e_b_B_list)
    np.save(f'e_r_B_list_{seed}_{n_bs}.npy', e_r_B_list)
    np.save(f'e_b_R_list_{seed}_{n_bs}.npy', e_b_R_list)
    np.save(f'e_r_R_list_{seed}_{n_bs}.npy', e_r_R_list)

    # compute p-values (centered bootstrap: T* - T compared with T)
    # NOTE(review): b_skewed here is e_b_R - e_r_R, the opposite sign of the
    # b_skewed used in bs_p_value_skewness; confirm which direction is intended.
    b_skewed_list = e_b_R_list - e_r_R_list
    r_skewed_list = e_r_B_list - e_b_B_list
    b_skewed = e_b_R - e_r_R
    r_skewed = e_r_B - e_b_B

    p_b = np.mean(b_skewed_list - b_skewed >= b_skewed)
    p_r = np.mean(r_skewed_list - r_skewed >= r_skewed)
    p_max = np.max([p_b, p_r])

    print(f'p_b: {p_b}, p_r: {p_r}, p_max: {p_max}')

    return p_b, p_r


def bs_p_value_skewness(clf_r, X_test_blue, y_test_blue, X_test_red, y_test_red,
                        e_b_R, e_r_R, n_bs=10000, seed=42):
    """Bootstrap p-value for testing b-skewness

    The blue and the red test sets are resampled with replacement
    (separately) ``n_bs`` times, and ``clf_r`` is re-evaluated on every
    resample. With ``T = e_r_R - e_b_R``, the p-value is the share of
    centered bootstrap statistics ``T* - T`` that are strictly ``> T``. A
    small value presumably indicates ``e_b_R < e_r_R`` (TODO confirm).

    Parameters
    ----------
    clf_r : sklearn.base.BaseEstimator
        fitted classifier (trained on the red group)
    X_test_blue, y_test_blue, X_test_red, y_test_red
        test sets; X's are pandas DataFrames, y's are numpy arrays (a pandas
        Series is also accepted and is indexed by position)
    e_b_R, e_r_R : float
        blue- and red-group losses of ``clf_r`` on the full test sets
    n_bs : int, optional
        number of bootstrap samples, by default 10000
    seed : int, optional
        seed for NumPy's global RNG, by default 42

    Returns
    -------
    p_b : float
        p-value

    Notes
    -----
    Prints the p-value. The caller's test sets are not modified.
    """
    np.random.seed(seed)
    # Local, positionally indexed copies: the caller's objects are left
    # untouched, and rows are selected by position below.
    X_test_blue = X_test_blue.reset_index(drop=True)
    X_test_red = X_test_red.reset_index(drop=True)
    y_test_blue = np.asarray(y_test_blue)
    y_test_red = np.asarray(y_test_red)
    n_blue = X_test_blue.shape[0]
    n_red = X_test_red.shape[0]
    e_b_R_list = []
    e_r_R_list = []
    for _ in range(n_bs):
        idx_blue = np.random.choice(n_blue, n_blue, replace=True)
        idx_red = np.random.choice(n_red, n_red, replace=True)
        X_test_blue_bs = X_test_blue.iloc[idx_blue]
        y_test_blue_bs = y_test_blue[idx_blue]
        X_test_red_bs = X_test_red.iloc[idx_red]
        y_test_red_bs = y_test_red[idx_red]

        res_red = est_group_loss(clf_r, X_test_blue_bs, y_test_blue_bs,
                                 X_test_red_bs, y_test_red_bs)
        e_b_R_list.append(res_red['e_b'])
        e_r_R_list.append(res_red['e_r'])
    e_b_R_list = np.array(e_b_R_list)
    e_r_R_list = np.array(e_r_R_list)
    # Saving is disabled; uncomment to keep the bootstrap draws.
    # np.save(f'fullcov_e_b_R_list_{seed}_{n_bs}.npy', e_b_R_list)
    # np.save(f'fullcov_e_r_R_list_{seed}_{n_bs}.npy', e_r_R_list)

    # compute p-value (centered bootstrap: T* - T compared with T)
    b_skewed_list = e_r_R_list - e_b_R_list
    b_skewed = e_r_R - e_b_R
    p_b = np.mean(b_skewed_list - b_skewed > b_skewed)

    print(f'p: {p_b}')

    return p_b
